{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regularisation in NNs¶"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before we start doing anything, I think it's important to understand for NLP, this is the intuitive process on what we are trying to do when we are processing our data in the IMDB dataset:\n",
    "1. Tokenization: break sentence into individual words\n",
    "    - Before: `\"PyTorch seems really easy to use!\"`\n",
    "    - After: `[\"PyTorch\", \"seems\", \"really\", \"easy\", \"to\", \"use\", \"!\"]`\n",
    "2. Building vocabulary: build an index of words associated with unique numbers\n",
    "    - Before: `[\"PyTorch\", \"seems\", \"really\", \"easy\", \"to\", \"use\", \"!\"]`\n",
    "    - After: `{\"Pytorch: 0, \"seems\": 1, \"really\": 2, ...}`\n",
    "3. Convert to numerals: map words to unique numbers (indices)\n",
    "    - Before: `{\"Pytorch: 0, \"seems\": 1, \"really\": 2, ...}`\n",
    "    - After: `[0, 1, 2, ...]`\n",
    "4. Embedding look-up: map sentences (indices now) to fixed matrices\n",
    "    - ```[[0.1, 0.4, 0.3],\n",
    "       [0.8, 0.1, 0.5],\n",
    "       ...]```"
   ]
  },
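  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The sketch below is illustrative only: the toy whitespace tokenizer and hand-built vocabulary are stand-ins for what torchtext will handle for us, and the embedding weights are randomly initialized."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "# 1. Tokenization (toy whitespace tokenizer, for illustration)\n",
    "sentence = \"pytorch seems really easy to use !\"\n",
    "tokens = sentence.split()\n",
    "print(tokens)\n",
    "\n",
    "# 2. Building the vocabulary: each unique word gets a unique index\n",
    "vocab = {word: idx for idx, word in enumerate(tokens)}\n",
    "print(vocab)\n",
    "\n",
    "# 3. Numericalization: map tokens to their vocabulary indices\n",
    "indices = torch.tensor([vocab[word] for word in tokens])\n",
    "print(indices)\n",
    "\n",
    "# 4. Embedding look-up: each index is mapped to a dense vector (randomly initialized here)\n",
    "embedding = nn.Embedding(num_embeddings=len(vocab), embedding_dim=3)\n",
    "print(embedding(indices))  # shape: (num_tokens, 3)"
   ]
  },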
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Critical plotting imports\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "# PyTorch imports\n",
    "from torchtext import data, datasets\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Checking for iterable objects\n",
    "import collections\n",
    "import random"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set seed\n",
    "torch.manual_seed(1337)\n",
    "if torch.cuda.is_available():\n",
    "    torch.cuda.manual_seed_all(1337)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set plotting style\n",
    "plt.style.use(('dark_background', 'bmh'))\n",
    "plt.rc('axes', facecolor='none')\n",
    "plt.rc('figure', figsize=(16, 4))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create instances of fields\n",
    "# The important field here is fix_length: all examples using this field will be padded to, or None for flexible sequence lengths\n",
    "# We are fixing this because we will be using a FNN not an LSTM/RNN/GRU where we can go through uneven sequence lengths\n",
    "max_len = 80\n",
    "text = data.Field(sequential=True, fix_length=max_len, batch_first=True, lower=True, dtype=torch.long)\n",
    "label = data.LabelField(sequential=False, dtype=torch.float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calling splits() class method of datasets.IMDB to return a torchtext.data.Dataset object\n",
    "datasets.IMDB.download('./')\n",
    "ds_train, ds_test = datasets.IMDB.splits(text, label, path='./imdb/aclImdb/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training and test set each 25k samples\n",
    "# 2 fields due to the way we split above\n",
    "print('train : ', len(ds_train))\n",
    "print('test : ', len(ds_test))\n",
    "print('train.fields :', ds_train.fields)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get validation set\n",
    "seed_num = 1337\n",
    "ds_train, ds_valid = ds_train.split(random_state=random.seed(seed_num))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Now we've training, validation and test set\n",
    "print('train : ', len(ds_train))\n",
    "print('valid : ', len(ds_valid))\n",
    "print('valid : ', len(ds_test))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build vocabulary\n",
    "# num_words = 25000\n",
    "num_words = 1000\n",
    "text.build_vocab(ds_train, max_size=num_words)\n",
    "label.build_vocab(ds_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print vocab size\n",
    "print('Vocabulary size: {}'.format(len(text.vocab)))\n",
    "print('Label size: {}'.format(len(label.vocab)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print most common vocabulary text\n",
    "most_common_samples = 10\n",
    "print(text.vocab.freqs.most_common(most_common_samples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print most common labels\n",
    "print(label.vocab.freqs.most_common())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 0 label\n",
    "ds_train[0].label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Sample 0 text: broken down into individual portions\n",
    "ds_train[0].text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample 0 text: human readeable sample\n",
    "def show_text(sample):\n",
    "    print(' '.join(word for word in sample))\n",
    "    \n",
    "show_text(ds_train[0].text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create and iterable object for our training, validation and testing datasets\n",
    "# Batches examples of similar lengths together that minimizes amount of padding needed\n",
    "batch_size = 64  # Change batch size from 1 to bigger number once explanation is done\n",
    "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n",
    "    (ds_train, ds_valid, ds_test), batch_size=batch_size, sort_key=lambda x: len(x.text), repeat=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if iterator above is an iterable which should show True\n",
    "isinstance(train_loader, collections.Iterable)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What's inside this iteratable object? Our text and label although now everything is in machine format (not \"words\") but in numbers!\n",
    "# The text we saw above becomes a matrix of size 1 x 80 represented by the fixed length we defined before that\n",
    "list(train_loader)[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Alternative to above, this is much faster but the above code is easy to understand and implement\n",
    "next(train_loader.__iter__())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch = next(train_loader.__iter__())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# What methods can we call on this batch object? Text and label\n",
    "test_batch.fields"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's break this down to check what's in a batch\n",
    "test_batch.text"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1 comment per batch, each comment is limited to a size of 80 as we've defined\n",
    "test_batch.text.size()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_batch.label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Extremely weird problem in torchtext where BucketIterator returns a Batch object versus just a simple tuple of tensors containing our text index and labels\n",
    "# So let's fix this with a new class FixBatchGenerator\n",
    "\n",
    "class FixBatchGenerator:\n",
    "    def __init__(self, dl, x_field, y_field):\n",
    "        self.dl, self.x_field, self.y_field = dl, x_field, y_field\n",
    "        \n",
    "    def __len__(self):\n",
    "        return len(self.dl)\n",
    "    \n",
    "    def __iter__(self):\n",
    "        for batch in self.dl:\n",
    "            X = getattr(batch, self.x_field)\n",
    "            y = getattr(batch, self.y_field)\n",
    "            yield (X,y)\n",
    "            \n",
    "train_loader, valid_loader, test_loader = FixBatchGenerator(train_loader, 'text', 'label'), FixBatchGenerator(valid_loader, 'text', 'label'), FixBatchGenerator(test_loader, 'text', 'label')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Text index\n",
    "print(next(train_loader.__iter__())[0])\n",
    "\n",
    "# Text label\n",
    "print(next(train_loader.__iter__())[1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class FeedforwardNeuralNetModel(nn.Module):\n",
    "    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):\n",
    "        super(FeedforwardNeuralNetModel, self).__init__()\n",
    "        # Embedding layer\n",
    "        self.embedding = nn.Embedding(input_dim, embedding_dim)\n",
    "        \n",
    "        # Linear function\n",
    "        self.fc1 = nn.Linear(embedding_dim*embedding_dim, hidden_dim) \n",
    "\n",
    "        # Linear function (readout)\n",
    "        self.fc2 = nn.Linear(hidden_dim, output_dim)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        # Embedding\n",
    "        embedded = self.embedding(x)\n",
    "        embedded = embedded.view(-1, embedding_dim*embedding_dim)\n",
    "        # Linear function\n",
    "        out = self.fc1(embedded)\n",
    "\n",
    "        # Non-linearity\n",
    "        out = torch.relu(out)\n",
    "        \n",
    "        # Toggle 3: Dropout\n",
    "        # out = torch.dropout(out, 0.8)\n",
    "\n",
    "        # Linear function (readout)\n",
    "        # Take note here use a final sigmoid function so your loss should not go through sigmoid again.\n",
    "        # BCELoss is the right class to use as it doesn't pass your output through a sigmoid function again.\n",
    "        # In multi-class problems you're used to softmax which can be simplified to a logistic,\n",
    "        # function when you have a two-class problem.\n",
    "        out = self.fc2(out)\n",
    "        out = torch.sigmoid(out)\n",
    "    \n",
    "        return out"
   ]
  },
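  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick aside on the sigmoid + `BCELoss` pairing above: an alternative (and more numerically stable) pattern is to return raw logits from the model and use `nn.BCEWithLogitsLoss`, which fuses the sigmoid into the loss. A minimal sketch with toy values, just to show the two routes agree:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy logits and targets, for illustration only\n",
    "logits = torch.randn(4, 1)\n",
    "targets = torch.tensor([[0.], [1.], [1.], [0.]])\n",
    "\n",
    "# Route 1: apply the sigmoid ourselves, then BCELoss (what the model above does)\n",
    "loss_bce = nn.BCELoss()(torch.sigmoid(logits), targets)\n",
    "\n",
    "# Route 2: hand raw logits to BCEWithLogitsLoss, which applies the sigmoid internally\n",
    "loss_bce_logits = nn.BCEWithLogitsLoss()(logits, targets)\n",
    "\n",
    "print(loss_bce.item(), loss_bce_logits.item())  # the two values match"
   ]
  },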
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "input_dim = num_words + 2\n",
    "embedding_dim = max_len\n",
    "hidden_dim = 32\n",
    "output_dim = 1\n",
    "\n",
    "# Instantiate model class and assign to object\n",
    "model = FeedforwardNeuralNetModel(input_dim, embedding_dim, hidden_dim, output_dim)\n",
    "\n",
    "# Push model to CUDA device if available\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model.to(device)\n",
    "\n",
    "# Loss function\n",
    "criterion = nn.BCELoss()\n",
    "\n",
    "# Optimizer\n",
    "# Toggle 2: L2 Norm option - this is called weight decay\n",
    "# optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)"
   ]
  },
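  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on Toggle 2: in `torch.optim`, `weight_decay` adds `lambda * w` to each parameter's gradient, which for plain SGD is equivalent to adding `(lambda / 2) * ||w||^2` to the loss (for adaptive optimizers like Adam the two are not identical, which is what AdamW addresses). A minimal sketch with a single toy parameter, assuming plain SGD:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Route 1: explicit L2 penalty added to the loss\n",
    "w = nn.Parameter(torch.tensor([2.0]))\n",
    "lam = 0.1\n",
    "l2_loss = 0.5 * lam * (w ** 2).sum()\n",
    "l2_loss.backward()\n",
    "print(w.grad)  # tensor([0.2000]) = lam * w\n",
    "\n",
    "# Route 2: weight_decay adds the same lam * w term inside the optimizer update\n",
    "w2 = nn.Parameter(torch.tensor([2.0]))\n",
    "w2.grad = torch.zeros_like(w2)  # pretend the data loss contributed zero gradient\n",
    "torch.optim.SGD([w2], lr=1.0, weight_decay=lam).step()\n",
    "print(w2.data)  # tensor([1.8000]) = 2.0 - lr * (lam * 2.0)"
   ]
  },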
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of groups of parameters\n",
    "print('Number of groups of parameters {}'.format(len(list(model.parameters()))))\n",
    "print('-'*50)\n",
    "# Print parameters\n",
    "for i in range(len(list(model.parameters()))):\n",
    "    print(list(model.parameters())[i].size())\n",
    "print('-'*50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "iter = 0\n",
    "num_epochs = 10\n",
    "history_train_acc, history_val_acc, history_train_loss, history_val_loss = [], [], [], []\n",
    "best_accuracy = 0\n",
    "for epoch in range(num_epochs):\n",
    "#     print('-'*50)\n",
    "    for i, (samples, labels) in enumerate(train_loader):\n",
    "        # Training mode\n",
    "        model.train()\n",
    "        \n",
    "        # Load samples\n",
    "        samples = samples.view(-1, max_len).to(device)\n",
    "        labels = labels.view(-1, 1).to(device)\n",
    "\n",
    "        # Clear gradients w.r.t. parameters\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Forward pass to get output/logits\n",
    "        outputs = model(samples)\n",
    "\n",
    "        # Calculate Loss: softmax --> cross entropy loss\n",
    "        loss = criterion(outputs, labels)\n",
    "        \n",
    "        # Toggle 1: L1 norm, add to original loss\n",
    "        # fc1_params = torch.cat([x.view(-1) for x in model.fc1.parameters()])\n",
    "        # loss += 0.001 * torch.norm(fc1_params, 1)\n",
    "    \n",
    "        # Getting gradients w.r.t. parameters\n",
    "        loss.backward()\n",
    "\n",
    "        # Updating parameters\n",
    "        optimizer.step()\n",
    "\n",
    "        iter += 1\n",
    "\n",
    "        if iter % 100 == 0:\n",
    "            # Get training statistics\n",
    "            train_loss = loss.data.item()\n",
    "            \n",
    "            # Testing mode\n",
    "            model.eval()\n",
    "            # Calculate Accuracy         \n",
    "            correct = 0\n",
    "            total = 0\n",
    "            # Iterate through test dataset\n",
    "            for samples, labels in valid_loader:\n",
    "                # Load samples\n",
    "                samples = samples.view(-1, max_len).to(device)\n",
    "                labels = labels.view(-1).to(device)\n",
    "\n",
    "                # Forward pass only to get logits/output\n",
    "                outputs = model(samples)\n",
    "                \n",
    "                # Val loss\n",
    "                val_loss = criterion(outputs.view(-1, 1), labels.view(-1, 1))\n",
    "                \n",
    "                # We use a threshold to define. \n",
    "                # There is another way to do this with one-hot label. Feel free to explore and understand what are the pros/cons of each.\n",
    "                # This opens up a whole topic on why it becomes problematic when we expand beyond 2 class to 10 classes.\n",
    "                # Why do we encode? Why can't we do 0, 1, 2, 3, 4 etc. without one-hot encoding?\n",
    "                predicted = outputs.ge(0.5).view(-1)\n",
    "\n",
    "                # Total number of labels\n",
    "                total += labels.size(0)\n",
    "\n",
    "                # Total correct predictions\n",
    "                correct += (predicted.type(torch.FloatTensor).cpu() == labels.type(torch.FloatTensor)).sum().item()\n",
    "                # correct = (predicted == labels.byte()).int().sum().item()\n",
    "            \n",
    "            accuracy = 100. * correct / total\n",
    "        \n",
    "            # Print Loss\n",
    "            print('Iter: {} | Train Loss: {} | Val Loss: {} | Val Accuracy: {}'.format(iter, train_loss, val_loss.item(), round(accuracy, 2)))\n",
    "            \n",
    "            # Append to history\n",
    "            history_val_loss.append(val_loss.data.item())\n",
    "            history_val_acc.append(round(accuracy, 2))\n",
    "            history_train_loss.append(train_loss)\n",
    "            \n",
    "            # Save model when accuracy beats best accuracy\n",
    "            if accuracy > best_accuracy:\n",
    "                best_accuracy = accuracy\n",
    "                # We can load this best model on the validation set later\n",
    "                torch.save(model.state_dict(), 'best_model.pth')"
   ]
  },
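  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since the best checkpoint was saved above, here is a minimal sketch of how you could load it back and evaluate on the held-out test set, reusing the same thresholded-accuracy computation as the validation loop:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reload the best checkpoint and measure test accuracy\n",
    "model.load_state_dict(torch.load('best_model.pth'))\n",
    "model.eval()\n",
    "\n",
    "correct, total = 0, 0\n",
    "with torch.no_grad():\n",
    "    for samples, labels in test_loader:\n",
    "        samples = samples.view(-1, max_len).to(device)\n",
    "        labels = labels.view(-1).to(device)\n",
    "        outputs = model(samples)\n",
    "        predicted = outputs.ge(0.5).view(-1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted.float().cpu() == labels.float().cpu()).sum().item()\n",
    "\n",
    "print('Test accuracy: {:.2f}%'.format(100. * correct / total))"
   ]
  },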
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting loss graph\n",
    "plt.plot(history_train_loss, label='Train')\n",
    "plt.plot(history_val_loss, label='Validation')\n",
    "plt.title('Loss Graph')\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plotting validation accuracy graph\n",
    "plt.plot(history_val_acc)\n",
    "plt.title('Validation Accuracy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = torch.Tensor().to(device)\n",
    "for param_group in list(model.parameters()):\n",
    "    weights = torch.cat((param_group.view(-1), weights))\n",
    "    print(param_group.size())\n",
    "    \n",
    "# Toggle 0: No regularization\n",
    "weights_nothing = weights.cpu().detach().numpy()\n",
    "\n",
    "# Toggle 1: L1 norm on FC1\n",
    "# weights_L1 = weights.detach().numpy()\n",
    "\n",
    "# Toggle 2: L2 norm\n",
    "# weights_L2 = weights.detach().numpy()\n",
    "\n",
    "# Toggle 3: dropout\n",
    "# weights_dropout = weights.detach().numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.hist(weights_L1.reshape(-1), range=(-.5, .5), bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plt.hist(weights_nothing.reshape(-1), range=(-.5, .5), bins=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show weight distribution\n",
    "plt.hist((\n",
    "    weights_nothing.reshape(-1),\n",
    "    weights_L1.reshape(-1),\n",
    "    weights_L2.reshape(-1),\n",
    "), 49, range=(-.5, .5), label=(\n",
    "    'No-reg',\n",
    "    'L1',\n",
    "    'L2',\n",
    "))\n",
    "plt.legend();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:dl-minicourse] *",
   "language": "python",
   "name": "conda-env-dl-minicourse-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
